# ======================================================================
# Copyright TOTAL / CERFACS / LIRMM (04/2020)
# Contributor: Adrien Suau (<adrien.suau@cerfacs.fr>
#                           <adrien.suau@lirmm.fr>)
#
# This software is governed by the CeCILL-B license under French law and
# abiding  by the  rules of  distribution of free software. You can use,
# modify  and/or  redistribute  the  software  under  the  terms  of the
# CeCILL-B license as circulated by CEA, CNRS and INRIA at the following
# URL "http://www.cecill.info".
#
# As a counterpart to the access to  the source code and rights to copy,
# modify and  redistribute granted  by the  license, users  are provided
# only with a limited warranty and  the software's author, the holder of
# the economic rights,  and the  successive licensors  have only limited
# liability.
#
# In this respect, the user's attention is drawn to the risks associated
# with loading,  using, modifying and/or  developing or reproducing  the
# software by the user in light of its specific status of free software,
# that  may mean  that it  is complicated  to manipulate,  and that also
# therefore  means that  it is reserved for  developers and  experienced
# professionals having in-depth  computer knowledge. Users are therefore
# encouraged  to load and  test  the software's  suitability as  regards
# their  requirements  in  conditions  enabling  the  security  of their
# systems  and/or  data to be  ensured and,  more generally,  to use and
# operate it in the same conditions as regards security.
#
# The fact that you  are presently reading this  means that you have had
# knowledge of the CeCILL-B license and that you accept its terms.
# ======================================================================
import argparse
from pathlib import Path
import sys
import pickle

import numpy
import matplotlib.pyplot as plt

from qaths.applications.wave_equation.evolve_1D_dirichlet import (
    solve_1d_dirichlet_stationary,
)


def main() -> None:
    """Command-line entry point: obtain a wave-equation solution, then save or plot it.

    The solution is either loaded from a pickle file (``-i``/``--input``) or
    computed by ``solve_1d_dirichlet_stationary`` from the initial state
    ``sin(2 * pi * x)`` sampled on ``discretisation_size`` points of [0, 1].
    It is then either pickled to a file (``-o``/``--output``) or plotted with
    matplotlib.

    ``--input`` and ``--output`` are mutually exclusive: when both are given,
    a message is printed on stderr and the process exits with status 1.
    """

    parser = argparse.ArgumentParser(
        description=(
            "Plot (or save) the solution of the wave "
            "equation as computed by the solver."
        )
    )
    parser.add_argument(
        "discretisation_size", type=int, help="Number of discretisation points to use"
    )
    parser.add_argument("evolution_time", type=float, help="Physical time of evolution")
    parser.add_argument("epsilon", type=float, help="Desired precision of the result")
    parser.add_argument(
        "-t",
        "--trotter-order",
        type=int,
        default=1,
        help="Order of the Trotter formula to use",
    )
    parser.add_argument(
        "-o",
        "--output",
        type=Path,
        help=(
            "if set, the result will be saved in the file given and not plot "
            'happen. Cannot be used with "-i" or "--input"'
        ),
        required=False,
    )
    parser.add_argument(
        "-i",
        "--input",
        type=Path,
        help=(
            "if set, the result will be loaded from the given file and plotted. "
            'Cannot be used with "-o" or "--output"'
        ),
        required=False,
    )
    parser.add_argument(
        "--probability-threshold",
        "--pt",
        type=float,
        help=(
            "threshold used to filter out outcomes that are due to numerical errors "
            "in the simulator"
        ),
        default=1e-8,
    )
    parser.add_argument(
        "--imaginary-part-threshold",
        "--it",
        type=float,
        help=(
            "threshold used to assert that the 2-norm of the imaginary part of the "
            "returned solution is not too high"
        ),
        default=1e-10,
    )
    args = parser.parse_args()
    if args.input is not None and args.output is not None:
        print(
            'You should NOT use "-i" or "--input" with "-o" or "--output"',
            file=sys.stderr,
        )
        # sys.exit is always available; the bare exit() helper is only
        # injected by the `site` module and is meant for interactive use.
        sys.exit(1)

    # Abscissas of the discretisation points; also used as the plot x-axis.
    X = numpy.linspace(0, 1, args.discretisation_size)
    if args.input is not None:
        print(f"Using solution from '{args.input.absolute()}'.")
        # SECURITY NOTE: unpickling can execute arbitrary code. Only load
        # files that were produced by this script or come from a trusted source.
        with open(args.input, "rb") as f:
            solution = pickle.load(f)
    else:
        # The `numpy.float` alias was removed in NumPy 1.24; it was a plain
        # alias of the builtin `float`, so this keeps the same dtype.
        initial_state = numpy.sin(2 * numpy.pi * X, dtype=float)

        solution = solve_1d_dirichlet_stationary(
            args.evolution_time,
            args.discretisation_size,
            args.epsilon,
            initial_state,
            trotter_order=args.trotter_order,
            probability_threshold=args.probability_threshold,
            imaginary_part_threshold=args.imaginary_part_threshold,
        )
        print(solution)
    if args.output is not None:
        with open(args.output, "wb") as f:
            pickle.dump(solution, f)
        print(f"Solution saved in '{args.output.absolute()}'.")
    else:
        # Plot the resulting solution
        plt.plot(X, solution)
        plt.show()
